{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nlpaug.augmenter.char as nac\n",
    "import nlpaug.augmenter.word as naw\n",
    "import nlpaug.flow as naf\n",
    "\n",
    "from nlpaug.util import Method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sentence passed to both custom augmenters defined later in this notebook\n",
    "text = 'The quick brown'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Character Augmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nlpaug.augmenter.char import CharAugmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ThE quick brown'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CustomCharAug(CharAugmenter):\n",
    "    \"\"\"Character-level augmenter that substitutes characters using a custom model.\n",
    "\n",
    "    Target words are picked first, then target characters inside each picked word.\n",
    "    Each target character is replaced by one candidate sampled from the list that\n",
    "    ``self.model.predict`` returns for it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, name='CustChar_Aug', min_char=2, aug_char_min=1, aug_char_max=10, aug_char_p=0.3,\n",
    "                 aug_word_min=1, aug_word_max=10, aug_word_p=0.3, tokenizer=None, reverse_tokenizer=None,\n",
    "                 stopwords=None, verbose=0, stopwords_regex=None):\n",
    "        # Every argument is forwarded unchanged to CharAugmenter; only the action\n",
    "        # (\"substitute\") and the device ('cpu') are fixed by this subclass.\n",
    "        super().__init__(\n",
    "            name=name, action=\"substitute\", min_char=min_char, aug_char_min=aug_char_min,\n",
    "            aug_char_max=aug_char_max, aug_char_p=aug_char_p, aug_word_min=aug_word_min,\n",
    "            aug_word_max=aug_word_max, aug_word_p=aug_word_p, tokenizer=tokenizer,\n",
    "            reverse_tokenizer=reverse_tokenizer, stopwords=stopwords, device='cpu',\n",
    "            verbose=verbose, stopwords_regex=stopwords_regex)\n",
    "\n",
    "        self.model = self.get_model()\n",
    "\n",
    "    def substitute(self, data):\n",
    "        \"\"\"Replace selected characters of selected words in ``data``.\n",
    "\n",
    "        :param data: text to augment\n",
    "        :return: ``self.reverse_tokenizer`` applied to the augmented tokens, or ``data``\n",
    "            unchanged when no word index is available\n",
    "        \"\"\"\n",
    "        # Tokenize a text (e.g. The quick brown fox jumps over the lazy dog) to tokens (e.g. ['The', 'quick', ...])\n",
    "        tokens = self.tokenizer(data)\n",
    "        # Get target tokens\n",
    "        aug_word_idxes = self._get_aug_idxes(tokens, self.aug_word_min, self.aug_word_max, self.aug_word_p, Method.WORD)\n",
    "        # Same guard as the character level below: without it, a None result would make\n",
    "        # the membership test in the loop raise TypeError.\n",
    "        if aug_word_idxes is None:\n",
    "            return data\n",
    "\n",
    "        results = []\n",
    "        for token_i, token in enumerate(tokens):\n",
    "            # Do not augment if it is not the target\n",
    "            if token_i not in aug_word_idxes:\n",
    "                results.append(token)\n",
    "                continue\n",
    "\n",
    "            chars = self.token2char(token)\n",
    "            # Get target characters\n",
    "            aug_char_idxes = self._get_aug_idxes(chars, self.aug_char_min, self.aug_char_max, self.aug_char_p,\n",
    "                                                 Method.CHAR)\n",
    "            if aug_char_idxes is None:\n",
    "                results.append(token)\n",
    "                continue\n",
    "\n",
    "            new_chars = []\n",
    "            for char_i, char in enumerate(chars):\n",
    "                if char_i in aug_char_idxes:\n",
    "                    # 'self.model.predict' is expected to return a list of possible outcomes,\n",
    "                    # so one of them is sampled. Otherwise, no need to sample.\n",
    "                    new_chars.append(self.sample(self.model.predict(char), 1)[0])\n",
    "                else:\n",
    "                    new_chars.append(char)\n",
    "\n",
    "            results.append(''.join(new_chars))\n",
    "\n",
    "        return self.reverse_tokenizer(results)\n",
    "\n",
    "    def get_model(self):\n",
    "        \"\"\"Create the model whose ``predict`` supplies substitution candidates.\"\"\"\n",
    "        return TempModel()\n",
    "    \n",
    "# Define your model with \"predict\" function\n",
    "class TempModel:\n",
    "    \"\"\"Toy lookup model: maps a character to a list of candidate replacements.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Each known character maps to its opposite-case form\n",
    "        self.model = {\n",
    "            'T': ['t'],\n",
    "            'h': ['H'],\n",
    "            'e': ['E'],\n",
    "            'q': ['Q'],\n",
    "            'u': ['U'],\n",
    "            'i': ['I'],\n",
    "            'c': ['C'],\n",
    "            'k': ['K'],\n",
    "            'b': ['B'],\n",
    "            'r': ['R'],\n",
    "            'o': ['O'],\n",
    "            'w': ['W'],\n",
    "            'n': ['N']\n",
    "        }\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return the candidate list for ``x``; characters not in the table map to ``[x]``.\"\"\"\n",
    "        # you can implement your own logic\n",
    "        return self.model.get(x, [x])\n",
    "\n",
    "# Build the custom character augmenter with its default arguments and apply it to `text`.\n",
    "# NOTE(review): targets look randomly sampled, so the displayed result may differ between runs - confirm\n",
    "aug = CustomCharAug()\n",
    "aug.augment(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Word Augmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nlpaug.augmenter.word import WordAugmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The Custom2 quick brown'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CustomWordAug(WordAugmenter):\n",
    "    \"\"\"Word-level augmenter that inserts words sampled from a fixed custom word list.\"\"\"\n",
    "\n",
    "    def __init__(self, name='CustomWord_Aug', aug_min=1, aug_max=10,\n",
    "                 aug_p=0.3, stopwords=None, tokenizer=None, reverse_tokenizer=None,\n",
    "                 device='cpu', verbose=0, stopwords_regex=None):\n",
    "        # Arguments are forwarded to WordAugmenter with the action fixed to 'insert'.\n",
    "        # verbose is passed through (it used to be hard-coded to 0, ignoring the caller).\n",
    "        super(CustomWordAug, self).__init__(\n",
    "            action='insert', name=name, aug_min=aug_min, aug_max=aug_max,\n",
    "            aug_p=aug_p, stopwords=stopwords, tokenizer=tokenizer, reverse_tokenizer=reverse_tokenizer,\n",
    "            device=device, verbose=verbose, stopwords_regex=stopwords_regex)\n",
    "\n",
    "        self.model = self.get_model()\n",
    "\n",
    "    def insert(self, data):\n",
    "        \"\"\"\n",
    "        Insert words sampled from ``self.model`` at the selected token positions.\n",
    "\n",
    "        :param data: text to augment\n",
    "        :return: ``self.reverse_tokenizer`` applied to the tokens with the new words\n",
    "            inserted, or ``data`` unchanged when no position is selected\n",
    "        \"\"\"\n",
    "        tokens = self.tokenizer(data)\n",
    "        results = tokens.copy()\n",
    "\n",
    "        aug_idxes = self._get_random_aug_idxes(tokens)\n",
    "        if aug_idxes is None:\n",
    "            return data\n",
    "\n",
    "        # Work from the highest index down so an insertion never shifts a pending target\n",
    "        aug_idxes.sort(reverse=True)\n",
    "        for aug_idx in aug_idxes:\n",
    "            new_word = self.sample(self.model, 1)[0]\n",
    "            results.insert(aug_idx, new_word)\n",
    "\n",
    "        return self.reverse_tokenizer(results)\n",
    "\n",
    "    def get_model(self):\n",
    "        \"\"\"Return the list of words that ``insert`` samples from.\"\"\"\n",
    "        return ['Custom1', 'Custom2']\n",
    "        \n",
    "        \n",
    "\n",
    "# Build the custom word augmenter with its default arguments and apply it to `text`.\n",
    "# NOTE(review): insert positions look randomly chosen, so the displayed result may differ between runs - confirm\n",
    "aug = CustomWordAug()\n",
    "aug.augment(text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
